{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a213cfee",
   "metadata": {},
   "source": [
    "# Video search\n",
    "This example walks through the code needed to build a video search system. A VGG model extracts feature vectors from video frames, and Milvus stores and searches those vectors.\n",
    "## Data\n",
    "\n",
    "This example uses 10 animated GIFs to build an end-to-end system that finds videos from a query image. Readers can substitute their own video files.\n",
    "\n",
    "Download location: https://drive.google.com/file/d/1hS4ANTQx9xNr9AByiLVeA1rEnxycdxtZ/view?usp=sharing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8102f770",
   "metadata": {},
   "source": [
    "## Requirements\n",
    "\n",
    "| Python Packages | Docker Servers |\n",
    "| --------------- | -------------- |\n",
    "| pymilvus        | Milvus-1.1.0   |\n",
    "\n",
    "We assume familiarity with libraries such as TensorFlow."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c74d0fe1",
   "metadata": {},
   "source": [
    "## Up and Running\n",
    "\n",
    "\n",
    "### 1. Start Milvus Server\n",
    "\n",
    "```bash\n",
    "$ docker run -d --name milvus_cpu_1.0.0 --network my-net --ip 10.0.0.2 \\\n",
    "-p 19530:19530 \\\n",
    "-p 19121:19121 \\\n",
    "-v /home/$USER/milvus/db:/var/lib/milvus/db \\\n",
    "-v /home/$USER/milvus/conf:/var/lib/milvus/conf \\\n",
    "-v /home/$USER/milvus/logs:/var/lib/milvus/logs \\\n",
    "-v /home/$USER/milvus/wal:/var/lib/milvus/wal \\\n",
    "milvusdb/milvus:1.0.0-cpu-d030521-1ea92e\n",
    "```\n",
    "\n",
    "This demo uses Milvus 1.0. Refer to [Install Milvus](https://milvus.io/docs/v1.0.0/milvus_docker-cpu.md) for instructions on running Milvus with Docker."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "047b80af",
   "metadata": {},
   "source": [
    "## Code Overview\n",
    "### Connecting to Servers\n",
    "\n",
    "We start by connecting to the Milvus server. In this case the Docker container is running on localhost and uses the default port (19530)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "79d15825-1a18-45d0-9b11-559886c2e5fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connect to Milvus\n",
    "\n",
    "import milvus\n",
    "milv = milvus.Milvus(host = '127.0.0.1', port = 19530)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85801f40",
   "metadata": {},
   "source": [
    "### Building Collection and Setting Index\n",
    "\n",
    "The next step is to create a collection. A collection in Milvus is similar to a table in a relational database and stores all of the vectors. To create one, we specify a name, the dimension of the vectors it will hold, the `index_file_size`, and the `metric_type`. The `index_file_size` controls how large each data segment in the collection can grow. The `metric_type` is the distance function used to measure similarity; in this example we use Euclidean distance (L2). More information can be found in the Milvus documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "db730ba7-618e-40c4-9f9f-a45e7dbf3125",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Status(code=0, message='Create collection successfully!')\n"
     ]
    }
   ],
   "source": [
    "# Create the collection\n",
    "\n",
    "import time\n",
    "\n",
    "collection_name = \"test_collection\"\n",
    "milv.drop_collection(collection_name)  # Drop any existing collection with this name so the example starts clean\n",
    "\n",
    "collection_param = {\n",
    "            'collection_name': collection_name,\n",
    "            'dimension': 512,\n",
    "            'index_file_size': 1024,  # optional\n",
    "            'metric_type': milvus.MetricType.L2  # optional\n",
    "            }\n",
    "\n",
    "status, ok = milv.has_collection(collection_name)\n",
    "\n",
    "if not ok:\n",
    "    status = milv.create_collection(collection_param)\n",
    "    print(status)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e8cf359",
   "metadata": {},
   "source": [
    "After creating the collection, we assign it an index type. This can be done before or after inserting the data. When done before, indexes are built as data arrives and fills the data segments. In this example we use IVF_SQ8, which requires the `nlist` parameter. Each index type has its own parameters; see the Milvus documentation for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "ccbe53c8-4752-4872-914f-4d49096992d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(collection_name='test_collection', index_type=<IndexType: IVF_SQ8>, params={'nlist': 512})\n"
     ]
    }
   ],
   "source": [
    "# Build the index\n",
    "\n",
    "index_param = {\n",
    "    'nlist': 512\n",
    "}\n",
    "\n",
    "status = milv.create_index(collection_name, milvus.IndexType.IVF_SQ8, index_param)\n",
    "status, index = milv.get_index_info(collection_name)\n",
    "print(index)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc36cdc6",
   "metadata": {},
   "source": [
    "### Processing and Storing Videos\n",
    "\n",
    "To store the videos in Milvus, we first need to extract frames from each video. Here we use OpenCV to do so."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "24664b23-94df-41ab-ab7a-73c12ba75d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import os\n",
    "import shutil\n",
    "save_path = \"/data1/lcl/test_video_bootcamp/frame_res/\"      # Path of save frame\n",
    "if not os.path.exists(save_path):\n",
    "    os.makedirs(save_path)  \n",
    "path = \"/data1/lcl/test_video_bootcamp/examle-gif-10/10-gif/\"    #Path of raw video \n",
    "\n",
    "filelist = os.listdir(path)     \n",
    "print(filelist)     \n",
    "for item in filelist:  \n",
    "    if item.endswith('.gif'):     # Write according to its own video file suffix, personal video files are in gif format\n",
    "        print(item)\n",
    "        try:\n",
    "            src = os.path.join(path, item)\n",
    "            vid_cap = cv2.VideoCapture(src)    \n",
    "            success, image = vid_cap.read()\n",
    "            count = 0\n",
    "            while success:\n",
    "                vid_cap.set(cv2.CAP_PROP_POS_MSEC, 1 * 1000 * count)   #The method of intercepting images Here is 1 second to intercept one Can change the parameters to set the interval of interception time\n",
    "                video_to_picture_path= os.path.join(save_path, item.split(\".\")[0])    # Naming of video folders\n",
    "                if not os.path.exists(video_to_picture_path):   #Create a folder corresponding to each video storage image\n",
    "                    os.makedirs(video_to_picture_path)\n",
    "                cv2.imwrite(video_to_picture_path+\"/\" + str(item.split(\".\")[0]) + \"#\" + str(count) + \".jpg\", image)       # Addresses of stored images and naming of images\n",
    "                success, image = vid_cap.read()\n",
    "                count += 1\n",
    "            print('Total frames: ', count)     #Print the number of intercepted images \n",
    "        except:\n",
    "            print(\"error\")\n",
    "            \n",
    "ALL_PIC = \"/data1/lcl/test_video_bootcamp/all_pic\"\n",
    "os.makedirs(ALL_PIC)\n",
    "\n",
    "for root, dirs, files in os.walk(save_path):\n",
    "    for file in files:\n",
    "        src_file = os.path.join(root, file)\n",
    "        shutil.copy(src_file, ALL_PIC)        "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "928a3567",
   "metadata": {},
   "source": [
    "Next, we use the VGG16 model to extract a 512-dimensional feature vector from each frame saved above, matching the dimension of the collection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "7a689e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "from keras.applications.vgg16 import VGG16\n",
    "from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg\n",
    "from keras.preprocessing import image\n",
    "import numpy as np\n",
    "from diskcache import Cache\n",
    "from numpy import linalg as LA\n",
    "\n",
    "class VGGNet:\n",
    "    def __init__(self):\n",
    "        self.input_shape = (224, 224, 3)\n",
    "        self.weight = 'imagenet'\n",
    "        self.pooling = 'max'\n",
    "        self.model_vgg = VGG16(weights=self.weight,\n",
    "                               input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),\n",
    "                               pooling=self.pooling,\n",
    "                               include_top=False)\n",
    "        self.model_vgg.predict(np.zeros((1, 224, 224, 3)))\n",
    "\n",
    "    def vgg_extract_feat(self, img_path):\n",
    "        img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))\n",
    "        img = image.img_to_array(img)\n",
    "        img = np.expand_dims(img, axis=0)\n",
    "        img = preprocess_input_vgg(img)\n",
    "        feat = self.model_vgg.predict(img)\n",
    "        norm_feat = feat[0] / LA.norm(feat[0])\n",
    "        norm_feat = [i.item() for i in norm_feat]\n",
    "        return norm_feat\n",
    "    \n",
    "def feature_extract(pic_path, model):\n",
    "    default_cache_dir=\"./tmp\"\n",
    "    cache = Cache(default_cache_dir)\n",
    "    feats = []\n",
    "    names = []\n",
    "    img_list = [os.path.join(pic_path, f) for f in os.listdir(pic_path) if (f.endswith('.jpg'))]\n",
    "    model = model\n",
    "    for i, img_path in enumerate(img_list):\n",
    "        norm_feat = model.vgg_extract_feat(img_path)\n",
    "        img_name = os.path.split(img_path)[1]\n",
    "        feats.append(norm_feat)\n",
    "        names.append(img_name.encode())\n",
    "        current = i+1\n",
    "        total = len(img_list)\n",
    "        cache['current'] = current\n",
    "        cache['total'] = total\n",
    "        print (\"extracting feature from image No. %d , %d images in total\" %(current, total))\n",
    "    return feats, names\n",
    "\n",
    "\n",
    "vectors, names = feature_extract(ALL_PIC, VGGNet())"
   ]
  },
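  {
   "cell_type": "markdown",
   "id": "4e2b7c1a",
   "metadata": {},
   "source": [
    "Because `vgg_extract_feat` divides each feature vector by its L2 norm, every stored vector has unit length. For unit vectors, `||a - b||^2 = 2 - 2 * cos(a, b)`, so ranking by Euclidean distance gives the same order as ranking by cosine similarity. The sketch below checks this identity with NumPy on random vectors; the helper `unit` and the variables `a` and `b` are illustrative assumptions, not part of the pipeline.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def unit(v):\n",
    "    # Scale a vector to unit L2 norm, as vgg_extract_feat does with feat[0]\n",
    "    return v / np.linalg.norm(v)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "a = unit(rng.random(512))\n",
    "b = unit(rng.random(512))\n",
    "\n",
    "squared_l2 = float(np.sum((a - b) ** 2))\n",
    "cosine = float(np.dot(a, b))\n",
    "\n",
    "# For unit vectors: ||a - b||^2 == 2 - 2 * cos(a, b)\n",
    "print(np.isclose(squared_l2, 2 - 2 * cosine))\n",
    "```"
   ]
  },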
  {
   "cell_type": "markdown",
   "id": "005c8168",
   "metadata": {},
   "source": [
    "Insert the image vectors into Milvus, then store each returned vector ID in a local diskcache store together with its image name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "5f2ac471-f13f-4bad-8077-f85256b6cfaa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Status(code=0, message='Add vectors successfully!')\n"
     ]
    }
   ],
   "source": [
    "from diskcache import Cache\n",
    "\n",
    "default_cache_dir=\"./tmp\"\n",
    "cache = Cache(default_cache_dir)\n",
    "\n",
    "status, ids = milv.insert(collection_name=collection_name, records=vectors)\n",
    "print(status)\n",
    "\n",
    "# Map each Milvus vector ID to the name of the frame it came from\n",
    "for vector_id, name in zip(ids, names):\n",
    "    cache[vector_id] = name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b62bcbc",
   "metadata": {},
   "source": [
    "### Searching\n",
    "\n",
    "To search with a query image, we use the same VGG model to extract its feature vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "e2d58d45-775e-453d-a4e2-80ebfefda180",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "from keras.applications.vgg16 import VGG16\n",
    "from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg\n",
    "from keras.preprocessing import image\n",
    "import numpy as np\n",
    "from numpy import linalg as LA\n",
    "\n",
    "class VGGNet:\n",
    "    def __init__(self):\n",
    "        self.input_shape = (224, 224, 3)\n",
    "        self.weight = 'imagenet'\n",
    "        self.pooling = 'max'\n",
    "        self.model_vgg = VGG16(weights=self.weight,\n",
    "                               input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),\n",
    "                               pooling=self.pooling,\n",
    "                               include_top=False)\n",
    "        self.model_vgg.predict(np.zeros((1, 224, 224, 3)))\n",
    "\n",
    "    def vgg_extract_feat(self, img_path):\n",
    "        img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))\n",
    "        img = image.img_to_array(img)\n",
    "        img = np.expand_dims(img, axis=0)\n",
    "        img = preprocess_input_vgg(img)\n",
    "        feat = self.model_vgg.predict(img)\n",
    "        norm_feat = feat[0] / LA.norm(feat[0])\n",
    "        norm_feat = [i.item() for i in norm_feat]\n",
    "        return norm_feat\n",
    "\n",
    "def search_feature_extract(pic_path, model):\n",
    "    feats = []\n",
    "    norm_feat = model.vgg_extract_feat(pic_path)\n",
    "    feats.append(norm_feat)\n",
    "    return feats\n",
    "\n",
    "embeddings = search_feature_extract(\"/data1/lcl/test_video_bootcamp/frame_res/tumblr_lhns1x9P9d1qc3h7bo1_400/tumblr_lhns1x9P9d1qc3h7bo1_400#1.jpg\",VGGNet())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68544d93",
   "metadata": {},
   "source": [
    "We can now use these embeddings in a search. The search takes the name of the collection, the query vectors, the number of closest vectors to return (`top_k`), and the index-specific search parameters, in this case `nprobe`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "c22c1158",
   "metadata": {},
   "outputs": [],
   "source": [
    "search_sub_param = {\n",
    "        \"nprobe\": 16\n",
    "    }\n",
    "\n",
    "search_param = {\n",
    "    'collection_name': collection_name,\n",
    "    'query_records': embeddings,\n",
    "    'top_k': 10,\n",
    "    'params': search_sub_param,\n",
    "    }\n",
    "\n",
    "start = time.time()\n",
    "status, results = milv.search(**search_param)\n",
    "print (status)\n",
    "end = time.time() - start\n",
    "\n",
    "print(\"Search took a total of: \", end)"
   ]
  },
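  {
   "cell_type": "markdown",
   "id": "9d3f6a20",
   "metadata": {},
   "source": [
    "Conceptually, the search returns the IDs and distances of the `top_k` stored vectors closest to each query. The brute-force NumPy sketch below illustrates that computation on toy data. It is for illustration only and is not how Milvus performs the search: an IVF index such as IVF_SQ8 scans only `nprobe` of the `nlist` clusters rather than every vector. The function name `brute_force_top_k` and the sample arrays are assumptions made for this example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def brute_force_top_k(stored, query, top_k):\n",
    "    # Euclidean distance from the query to every stored vector\n",
    "    distances = np.linalg.norm(stored - query, axis=1)\n",
    "    # Indices of the top_k smallest distances, closest first\n",
    "    order = np.argsort(distances)[:top_k]\n",
    "    return order.tolist(), distances[order].tolist()\n",
    "\n",
    "stored = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])\n",
    "query = np.array([0.9, 0.1])\n",
    "\n",
    "ids, dists = brute_force_top_k(stored, query, top_k=2)\n",
    "print(ids, dists)\n",
    "```"
   ]
  },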
  {
   "cell_type": "markdown",
   "id": "20b735af",
   "metadata": {},
   "source": [
    "The search result contains the IDs and corresponding distances of the `top_k` closest vectors. We look up those IDs in the diskcache store to recover the frame names, and from them the original videos."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "ddaf2e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def query_name_from_ids(vids):\n",
    "    res = []\n",
    "    cache = Cache(default_cache_dir)\n",
    "    for i in vids:\n",
    "        if i in cache:\n",
    "            res.append(cache[i])\n",
    "    return res\n",
    "\n",
    "\n",
    "if status.OK():\n",
    "    vids = [x.id for x in results[0]]\n",
    "    res_name = [x.decode('utf-8') for x in query_name_from_ids(vids)]\n",
    "    res_video_name = []\n",
    "    for pic_name in res_name:\n",
    "        video_name = pic_name[0:pic_name.rfind('#', 1)] + \".gif\"\n",
    "        res_video_name.append(video_name)\n",
    "    print(res_video_name)\n",
    "    \n",
    "else:\n",
    "    print(\"Search Failed.\")"
   ]
  },
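  {
   "cell_type": "markdown",
   "id": "5a8c1e47",
   "metadata": {},
   "source": [
    "The frame-to-video mapping above relies on the naming scheme used when the frames were saved, `<video name>#<frame index>.jpg`. Several of the top hits can come from the same video, so it may be useful to collapse duplicates while keeping the ranking order. A minimal sketch, assuming that naming scheme; `frame_to_video` and `unique_in_order` are hypothetical helpers, not part of the code above.\n",
    "\n",
    "```python\n",
    "def frame_to_video(pic_name, suffix='.gif'):\n",
    "    # Everything before the last '#' is the original video's base name\n",
    "    return pic_name[:pic_name.rfind('#')] + suffix\n",
    "\n",
    "def unique_in_order(items):\n",
    "    # dict preserves insertion order, so the first (closest) hit wins\n",
    "    return list(dict.fromkeys(items))\n",
    "\n",
    "frames = ['cat#3.jpg', 'dog#0.jpg', 'cat#1.jpg']\n",
    "videos = unique_in_order(frame_to_video(f) for f in frames)\n",
    "print(videos)\n",
    "```"
   ]
  },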
  {
   "cell_type": "markdown",
   "id": "6b1ac1c6",
   "metadata": {},
   "source": [
    "This is the basic way to conduct a video search."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
